{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "GgclUjr3sT_E"
   },
   "source": [
    "# IDAKLU-JAX interface\n",
    "\n",
    "The IDAKLU-JAX interface requires that PyBaMM is installed with the [optional JAX solver enabled](https://docs.pybamm.org/en/stable/source/user_guide/installation/gnu-linux-mac.html#optional-jaxsolver) (`pip install pybamm[jax]`) and requires at least Python 3.10.\n",
    "\n",
    "PyBaMM provides two mechanisms to interface battery models with JAX. The first (JaxSolver) implements PyBaMM models directly in native JAX, and as such provides the greatest flexibility. However, these models can be very slow to compile, especially during their initial run, and can require large amounts of memory.\n",
    "\n",
    "The second (the IDAKLU-Jax interface) instead provides a JAX-compliant interface to the IDAKLU solver. IDAKLU is a fast (compiled) solver based on SUNDIALS. By exposing the IDAKLU solver to JAX, we provide a fast solver capable of interfacing with third-party JAX-compatible software libraries, such as numpyro.\n",
    "\n",
    "Despite the apparent advantages, there are some limitations to this approach. The most notable is that model derivatives are limited to first-order (i.e. sensitivities), since the IDAKLU solver is not capable of auto-differentiation."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "zsJLlejtzjcC"
   },
   "source": [
    "## Set up a basic DFN model\n",
    "\n",
    "To demonstrate use of the IDAKLU-Jax interface, we first set up a basic model, choosing the DFN model in this case. We will provide two `inputs` to the model and will specify a list of variables of interest (`output_variables`). Specifying `output_variables` is strongly recommended to reduce computational load, while `inputs` are only required when derivatives are to be considered."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Note: you may need to restart the kernel to use updated packages.\n"
     ]
    }
   ],
   "source": [
    "%pip install \"pybamm[jax]\" -q    # install PyBaMM with JAX support if it is not installed\n",
    "import time\n",
    "\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import numpy as np\n",
    "\n",
    "import pybamm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "id": "99tT6F73sRyc"
   },
   "outputs": [],
   "source": [
    "# We will want to differentiate our model, so let's define two input parameters\n",
    "inputs = {\n",
    "    \"Current function [A]\": 0.222,\n",
    "    \"Separator porosity\": 0.3,\n",
    "}\n",
    "\n",
    "# Set-up the model\n",
    "model = pybamm.lithium_ion.DFN()\n",
    "geometry = model.default_geometry\n",
    "param = model.default_parameter_values\n",
    "param.update({key: \"[input]\" for key in inputs.keys()})\n",
    "param.process_geometry(geometry)\n",
    "param.process_model(model)\n",
    "var = pybamm.standard_spatial_vars\n",
    "var_pts = {var.x_n: 20, var.x_s: 20, var.x_p: 20, var.r_n: 10, var.r_p: 10}\n",
    "mesh = pybamm.Mesh(geometry, model.default_submesh_types, var_pts)\n",
    "disc = pybamm.Discretisation(mesh, model.default_spatial_methods)\n",
    "disc.process_model(model)\n",
    "\n",
    "# Use a short time-vector for this example, and declare which variables to track\n",
    "t_eval = np.linspace(0, 360, 10)\n",
    "output_variables = [\n",
    "    \"Voltage [V]\",\n",
    "    \"Current [A]\",\n",
    "    \"Time [min]\",\n",
    "]\n",
    "\n",
    "# Create the IDAKLU Solver object\n",
    "idaklu_solver = pybamm.IDAKLUSolver(\n",
    "    rtol=1e-6,\n",
    "    atol=1e-6,\n",
    "    output_variables=output_variables,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "QhMDaAjt0DR9"
   },
   "source": [
    "Next, we jaxify the IDAKLU solver, calling `jaxify()` with similar arguments to those we would pass to a normal IDAKLU solve. The only difference is that `jaxify()` returns an `IDAKLUJax` object instead of a `Solution` object. We keep hold of this object and can request a JAX expression from it using the `get_jaxpr()` method, as below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "id": "rk4RYT2-z6BD"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "JAX expression: <function IDAKLUJax._jaxify.<locals>.f at 0x132d50f40>\n"
     ]
    }
   ],
   "source": [
    "# This is how we would normally perform a solve using IDAKLU\n",
    "sim = idaklu_solver.solve(\n",
    "    model,\n",
    "    t_eval,\n",
    "    inputs=inputs,\n",
    "    calculate_sensitivities=True,\n",
    ")\n",
    "\n",
    "# Instead, we Jaxify the IDAKLU solver using similar arguments...\n",
    "jax_solver = idaklu_solver.jaxify(\n",
    "    model,\n",
    "    t_eval,\n",
    ")\n",
    "\n",
    "# ... and then obtain a JAX expression for the solve\n",
    "f = jax_solver.get_jaxpr()\n",
    "print(f\"JAX expression: {f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Wjp4Fpj402Ah"
   },
   "source": [
    "The JAX expression (which we named `f` in our example) is a function that can be used and evaluated like any other native JAX expression. This means that it can be included in broader JAX expressions, and can even be JIT compiled. The only limitations are that:\n",
    "1) derivatives cannot be taken beyond first-order, which is the limit of our IDAKLU solver implementation, and\n",
    "2) you are required to specify `output_variables` either at the `IDAKLUSolver` stage, or at the `jaxify` stage (you can create many jaxified expressions from a single solver object).\n",
    "\n",
    "Here is the most basic usage example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "id": "VCKYxXMD0xTX"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[3.81930814e+000 2.22000000e-001 4.95024341e-316]\n",
      " [3.81346107e+000 2.22000000e-001 6.66666667e-001]\n",
      " [3.81080090e+000 2.22000000e-001 1.33333333e+000]\n",
      " [3.80885531e+000 2.22000000e-001 2.00000000e+000]\n",
      " [3.80714541e+000 2.22000000e-001 2.66666667e+000]\n",
      " [3.80552362e+000 2.22000000e-001 3.33333333e+000]\n",
      " [3.80393909e+000 2.22000000e-001 4.00000000e+000]\n",
      " [3.80237338e+000 2.22000000e-001 4.66666667e+000]\n",
      " [3.80081962e+000 2.22000000e-001 5.33333333e+000]\n",
      " [3.79927489e+000 2.22000000e-001 6.00000000e+000]]\n"
     ]
    }
   ],
   "source": [
    "# Print all output variables, evaluated over a given time vector\n",
    "data = f(t_eval, inputs)\n",
    "print(data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "tUZurVD26t4Z"
   },
   "source": [
    "Here we see a matrix of shape (N, 3), where N is the number of time samples in `t_eval`, and the three columns correspond to our three `output_variables`. We can evaluate the expression at any point within our time span (e.g. `f(0.0, inputs)`), or at multiple points (such as the full range of `t_eval`, as in our example). To help isolate output variables, the IDAKLU-Jax interface provides several helper functions. Below we demonstrate isolating a single variable using the `get_var` helper. You can also isolate multiple variables, provided as a list, by using the `get_vars` helper function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "id": "7UnY6goK633s"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Isolating a single variable returns an array of shape (10,)\n",
      "[3.81930814 3.81346107 3.8108009  3.80885531 3.80714541 3.80552362\n",
      " 3.80393909 3.80237338 3.80081962 3.79927489]\n",
      "\n",
      "Isolating two variables returns an array of shape (10, 2)\n",
      "[[3.81930814 0.222     ]\n",
      " [3.81346107 0.222     ]\n",
      " [3.8108009  0.222     ]\n",
      " [3.80885531 0.222     ]\n",
      " [3.80714541 0.222     ]\n",
      " [3.80552362 0.222     ]\n",
      " [3.80393909 0.222     ]\n",
      " [3.80237338 0.222     ]\n",
      " [3.80081962 0.222     ]\n",
      " [3.79927489 0.222     ]]\n"
     ]
    }
   ],
   "source": [
    "# Isolate a single variable\n",
    "data = jax_solver.get_var(\"Voltage [V]\")(t_eval, inputs)\n",
    "print(f\"Isolating a single variable returns an array of shape {data.shape}\")\n",
    "print(data)\n",
    "\n",
    "# Isolate two variables from the solver\n",
    "data = jax_solver.get_vars(\n",
    "    [\n",
    "        \"Voltage [V]\",\n",
    "        \"Current [A]\",\n",
    "    ],\n",
    ")(t_eval, inputs)\n",
    "print(f\"\\nIsolating two variables returns an array of shape {data.shape}\")\n",
    "print(data)"
   ]
  },
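  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Conceptually, isolating a variable amounts to wrapping the expression in an outer function that selects part of its output. The sketch below illustrates the idea on a toy function rather than on the solver itself: `toy_f`, `TOY_OUTPUTS` and `isolate` are hypothetical names introduced for illustration only, and the column-selection logic is an assumption about the general pattern, not a description of how `get_var` is implemented internally.\n",
    "\n",
    "```python\n",
    "import jax.numpy as jnp\n",
    "\n",
    "# Hypothetical output ordering for the toy function below\n",
    "TOY_OUTPUTS = ['Voltage [V]', 'Current [A]']\n",
    "\n",
    "\n",
    "def toy_f(t, inputs):\n",
    "    # Stand-in for a jaxified solve: returns one column per output variable\n",
    "    current = inputs['Current function [A]']\n",
    "    voltage = 4.0 - current * t\n",
    "    return jnp.stack([voltage, current * jnp.ones_like(t)], axis=-1)\n",
    "\n",
    "\n",
    "def isolate(fn, name):\n",
    "    # Wrap `fn` so that only the column belonging to `name` is returned\n",
    "    index = TOY_OUTPUTS.index(name)\n",
    "\n",
    "    def wrapped(t, inputs):\n",
    "        return fn(t, inputs)[..., index]\n",
    "\n",
    "    return wrapped\n",
    "\n",
    "\n",
    "toy_t = jnp.linspace(0.0, 1.0, 5)\n",
    "toy_inputs = {'Current function [A]': 0.5}\n",
    "toy_voltage = isolate(toy_f, 'Voltage [V]')\n",
    "print(toy_voltage(toy_t, toy_inputs).shape)\n",
    "```\n",
    "\n",
    "Because the wrapper keeps the `(t, inputs)` call signature, the wrapped function can be passed to JAX transformations in exactly the same way as the function it wraps."
   ]
  },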
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "IK32lBj9_rcW"
   },
   "source": [
    "As with any JAX expression, we can create new expressions by encapsulating them in outer functions (as further demonstrated below). The method `jax_solver.get_var()` does this for you by encapsulating `f` with a function that isolates a given variable of interest. We then evaluate that new expression by passing our usual arguments `(t_eval, inputs)`.\n",
    "\n",
    "To compute the Jacobian matrix (the matrix of derivatives of the output variables with respect to each input parameter), use the forward-mode `jax.jacfwd` or reverse-mode `jax.jacrev` functions.\n",
    "\n",
    "When calling these functions, `argnums=1` signifies that we are taking the Jacobian with respect to the second argument (`inputs`, indexing from 0). Since `inputs` is a dictionary of input parameters, the result is also a dictionary, holding the derivatives with respect to each dictionary key / input parameter. The two methods (`jacfwd` and `jacrev`) produce the same output; only the way in which the derivatives are computed differs. In general, the forward method tends to be slightly faster to run than the reverse method for our IDAKLU implementation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "id": "PmPfHSRu8N-_"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Jacobian forward method ran in 0.125 secs\n",
      "{'Current function [A]': Array([[-0.13643792,  1.        ,  0.        ],\n",
      "       [-0.16400861,  1.        ,  0.        ],\n",
      "       [-0.17630142,  1.        ,  0.        ],\n",
      "       [-0.18509421,  1.        ,  0.        ],\n",
      "       [-0.19273301,  1.        ,  0.        ],\n",
      "       [-0.19993145,  1.        ,  0.        ],\n",
      "       [-0.20692727,  1.        ,  0.        ],\n",
      "       [-0.21380043,  1.        ,  0.        ],\n",
      "       [-0.22057579,  1.        ,  0.        ],\n",
      "       [-0.2272616 ,  1.        ,  0.        ]], dtype=float64), 'Separator porosity': Array([[0.00579553, 0.        , 0.        ],\n",
      "       [0.00797   , 0.        , 0.        ],\n",
      "       [0.0095281 , 0.        , 0.        ],\n",
      "       [0.01024868, 0.        , 0.        ],\n",
      "       [0.01053737, 0.        , 0.        ],\n",
      "       [0.0106461 , 0.        , 0.        ],\n",
      "       [0.01068649, 0.        , 0.        ],\n",
      "       [0.01070164, 0.        , 0.        ],\n",
      "       [0.01070816, 0.        , 0.        ],\n",
      "       [0.01071172, 0.        , 0.        ]], dtype=float64)}\n",
      "\n",
      "Jacobian reverse method ran in 0.196 secs\n",
      "{'Current function [A]': Array([[-0.13643792,  1.        ,  0.        ],\n",
      "       [-0.16400861,  1.        ,  0.        ],\n",
      "       [-0.17630142,  1.        ,  0.        ],\n",
      "       [-0.18509421,  1.        ,  0.        ],\n",
      "       [-0.19273301,  1.        ,  0.        ],\n",
      "       [-0.19993145,  1.        ,  0.        ],\n",
      "       [-0.20692727,  1.        ,  0.        ],\n",
      "       [-0.21380043,  1.        ,  0.        ],\n",
      "       [-0.22057579,  1.        ,  0.        ],\n",
      "       [-0.2272616 ,  1.        ,  0.        ]],      dtype=float64, weak_type=True), 'Separator porosity': Array([[0.00579553, 0.        , 0.        ],\n",
      "       [0.00797   , 0.        , 0.        ],\n",
      "       [0.0095281 , 0.        , 0.        ],\n",
      "       [0.01024868, 0.        , 0.        ],\n",
      "       [0.01053737, 0.        , 0.        ],\n",
      "       [0.0106461 , 0.        , 0.        ],\n",
      "       [0.01068649, 0.        , 0.        ],\n",
      "       [0.01070164, 0.        , 0.        ],\n",
      "       [0.01070816, 0.        , 0.        ],\n",
      "       [0.01071172, 0.        , 0.        ]],      dtype=float64, weak_type=True)}\n"
     ]
    }
   ],
   "source": [
    "# Calculate the Jacobian matrix (via forward autodiff)\n",
    "t_start = time.time()\n",
    "out = jax.jacfwd(f, argnums=1)(t_eval, inputs)\n",
    "print(f\"Jacobian forward method ran in {time.time() - t_start:0.3} secs\")\n",
    "print(out)\n",
    "\n",
    "# Calculate Jacobian matrix (via backward autodiff)\n",
    "t_start = time.time()\n",
    "out = jax.jacrev(f, argnums=1)(t_eval, inputs)\n",
    "print(f\"\\nJacobian reverse method ran in {time.time() - t_start:0.3} secs\")\n",
    "print(out)"
   ]
  },
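  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The structure of these results (one array per input key, each with the same shape as the function output) is standard JAX behaviour for dictionary arguments and does not depend on the solver. The following minimal sketch reproduces it on a toy function; `toy_model`, `jac_t` and `jac_inputs` are hypothetical stand-ins that are unrelated to the DFN model.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def toy_model(t, inputs):\n",
    "    # Stand-in for a jaxified solve: returns an (N, 2) array of outputs\n",
    "    decay = inputs['a'] * jnp.exp(-inputs['b'] * t)\n",
    "    ramp = inputs['a'] * t\n",
    "    return jnp.stack([decay, ramp], axis=-1)\n",
    "\n",
    "\n",
    "jac_t = jnp.linspace(0.0, 1.0, 5)\n",
    "jac_inputs = {'a': 2.0, 'b': 0.5}\n",
    "\n",
    "# Differentiate with respect to the second argument (the dictionary)\n",
    "jac_fwd = jax.jacfwd(toy_model, argnums=1)(jac_t, jac_inputs)\n",
    "jac_rev = jax.jacrev(toy_model, argnums=1)(jac_t, jac_inputs)\n",
    "\n",
    "for key in jac_inputs:\n",
    "    # Each entry has the same shape as the model output\n",
    "    print(key, jac_fwd[key].shape)\n",
    "    # Forward and reverse modes agree up to floating-point error\n",
    "    print(bool(jnp.allclose(jac_fwd[key], jac_rev[key], atol=1e-6)))\n",
    "```\n",
    "\n",
    "As a rule of thumb, forward mode suits functions with few inputs and many outputs, which is the situation here: a handful of input parameters and one output row per time sample."
   ]
  },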
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To extract the relevant data vector from the above expression, we can again make use of the `get_var()` helper function, which can also take numpy arrays as input, for example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[-0.13643792 -0.16400861 -0.17630142 -0.18509421 -0.19273301 -0.19993145\n",
      " -0.20692727 -0.21380043 -0.22057579 -0.2272616 ]\n"
     ]
    }
   ],
   "source": [
    "# Isolate the derivative of Voltage with respect to the Current function:\n",
    "out = jax.jacfwd(f, argnums=1)(t_eval, inputs)\n",
    "data = jax_solver.get_var(out[\"Current function [A]\"], \"Voltage [V]\")\n",
    "print(data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "2tC9Bp_g9mOp"
   },
   "source": [
    "The gradient (`grad`) function, on the other hand, requires the underlying function to return a scalar value. The function must therefore be called separately for each time sample, and can only be evaluated for one output variable at a time. We can satisfy these restrictions with our JAX expression `f` through use of the `get_var` and `vmap` functions (the latter of which provides vector-mapping over time)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "id": "sJjWUIcG9lWa"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Gradient method ran in 0.105 secs\n",
      "{'Current function [A]': Array([-0.13643792, -0.16400861, -0.17630142, -0.18509421, -0.19273301,\n",
      "       -0.19993145, -0.20692727, -0.21380043, -0.22057579, -0.2272616 ],      dtype=float64), 'Separator porosity': Array([0.00579553, 0.00797   , 0.0095281 , 0.01024868, 0.01053737,\n",
      "       0.0106461 , 0.01068649, 0.01070164, 0.01070816, 0.01071172],      dtype=float64)}\n"
     ]
    }
   ],
   "source": [
    "# Example evaluation using the `grad` function\n",
    "t_start = time.time()\n",
    "data = jax.vmap(\n",
    "    jax.grad(\n",
    "        jax_solver.get_var(\"Voltage [V]\"),\n",
    "        argnums=1,  # take derivative with respect to `inputs`\n",
    "    ),\n",
    "    in_axes=(0, None),  # map time over the 0th dimension and do not map inputs\n",
    ")(t_eval, inputs)\n",
    "print(f\"Gradient method ran in {time.time() - t_start:0.3} secs\")\n",
    "print(data)"
   ]
  },
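  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The same `vmap`-over-`grad` pattern can be tried on a small scalar-valued function to see how the mapping behaves. This is only a sketch: `toy_voltage_at` is a hypothetical closed-form stand-in for a single isolated output variable, chosen so that the derivative can be checked by hand.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def toy_voltage_at(t, inputs):\n",
    "    # Scalar time in, scalar value out, as `grad` requires\n",
    "    return inputs['a'] * jnp.exp(-inputs['b'] * t)\n",
    "\n",
    "\n",
    "grad_over_time = jax.vmap(\n",
    "    jax.grad(toy_voltage_at, argnums=1),  # derivative with respect to the inputs\n",
    "    in_axes=(0, None),  # map over time, share the same inputs for every sample\n",
    ")\n",
    "\n",
    "grad_t = jnp.linspace(0.0, 1.0, 5)\n",
    "grad_inputs = {'a': 2.0, 'b': 0.5}\n",
    "grads = grad_over_time(grad_t, grad_inputs)\n",
    "\n",
    "# One vector of length N per input key\n",
    "print({key: value.shape for key, value in grads.items()})\n",
    "# The derivative with respect to `a` is exp(-b * t)\n",
    "print(bool(jnp.allclose(grads['a'], jnp.exp(-0.5 * grad_t), atol=1e-6)))\n",
    "```\n",
    "\n",
    "Setting `in_axes=(0, None)` is what keeps the dictionary of inputs unmapped, so each time sample is differentiated against the same parameter values."
   ]
  },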
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "G0bo1TPL-ZAM"
   },
   "source": [
    "## A use-case example\n",
    "\n",
    "As a use-case example, consider a fitting procedure where we want to compare simulation data against some experimental data. We achieve this by computing the sum of squared errors (SSE) between the two. Many fitting procedures will converge more quickly (with fewer iterations) if both the value *and gradient* of the SSE function are provided. By making use of JAX expressions we can obtain both with very little effort.\n",
    "\n",
    "*Note*: We do not need to map over time when calling `value_and_grad` in this example as the `sse` function returns a scalar (despite taking vector inputs)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "id": "56NPH9sZ-ZFq"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Value and gradient computed in 0.095 secs\n",
      "SSE value:  0.0020846163677995366\n",
      "SSE gradient (wrt each input):  {'Current function [A]': array(-0.05775429), 'Separator porosity': array(0.00146983)}\n"
     ]
    }
   ],
   "source": [
    "# Simulate some experimental data using our original parameter settings\n",
    "data = sim[\"Voltage [V]\"](t_eval)\n",
    "\n",
    "\n",
    "# Sum-of-squared errors\n",
    "def sse(t, inputs):\n",
    "    modelled = jax_solver.get_var(\"Voltage [V]\")(t, inputs)\n",
    "    return jnp.sum((modelled - data) ** 2)\n",
    "\n",
    "\n",
    "# Provide some predicted model inputs (these could come from a fitting procedure)\n",
    "inputs_pred = {\n",
    "    \"Current function [A]\": 0.150,\n",
    "    \"Separator porosity\": 0.333,\n",
    "}\n",
    "\n",
    "# Get the value and gradient of the SSE function\n",
    "t_start = time.time()\n",
    "value, gradient = jax.value_and_grad(sse, argnums=1)(t_eval, inputs_pred)\n",
    "print(f\"Value and gradient computed in {time.time() - t_start:0.3} secs\")\n",
    "print(\"SSE value: \", value)\n",
    "print(\"SSE gradient (wrt each input): \", gradient)"
   ]
  },
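  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To illustrate how the value and gradient might feed a fitting procedure, the sketch below runs plain gradient descent on a toy problem. Everything in it is an assumption made for illustration: `fit_model` is a closed-form stand-in for the battery model, and the fixed step size and iteration count were picked by hand for this toy problem. A real fit would more likely hand the value and gradient to an established optimiser.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "fit_t = jnp.linspace(0.0, 1.0, 20)\n",
    "\n",
    "\n",
    "def fit_model(t, inputs):\n",
    "    # Closed-form stand-in for the simulated voltage\n",
    "    return inputs['a'] * jnp.exp(-inputs['b'] * t)\n",
    "\n",
    "\n",
    "# Synthetic data generated from known parameter values\n",
    "fit_data = fit_model(fit_t, {'a': 2.0, 'b': 0.5})\n",
    "\n",
    "\n",
    "def fit_sse(t, inputs):\n",
    "    return jnp.sum((fit_model(t, inputs) - fit_data) ** 2)\n",
    "\n",
    "\n",
    "fit_value_and_grad = jax.value_and_grad(fit_sse, argnums=1)\n",
    "\n",
    "guess = {'a': 1.5, 'b': 0.8}\n",
    "step = 0.01  # hand-picked step size for this toy problem\n",
    "initial_sse, _ = fit_value_and_grad(fit_t, guess)\n",
    "for _ in range(300):\n",
    "    _, grad = fit_value_and_grad(fit_t, guess)\n",
    "    # Move each parameter a small distance against its gradient\n",
    "    guess = {key: guess[key] - step * grad[key] for key in guess}\n",
    "final_sse = fit_sse(fit_t, guess)\n",
    "\n",
    "print('SSE before and after:', float(initial_sse), float(final_sse))\n",
    "print('Fitted inputs:', {key: float(value) for key, value in guess.items()})\n",
    "```\n",
    "\n",
    "Because the gradient comes back as a dictionary with the same keys as the inputs, the update step can be written as a single dictionary comprehension."
   ]
  },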
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Nj0ylzso8Yu_"
   },
   "source": [
    "All of the above expressions can be JIT compiled (onto CPU) by using the `jax.jit` directive. In practice, the compiled expression simply calls back into the Python interface of the IDAKLU solver, so JIT support is provided only to afford maximum downstream compatibility (where JIT may be applied outside of the user's control)."
   ]
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
